/*
 * Copyright (c) 2023, 2023, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.  Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */
package com.oracle.svm.hosted.methodhandles;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodType;
import java.lang.reflect.Array;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import com.oracle.graal.pointsto.BigBang;
import com.oracle.svm.util.OriginalClassProvider;
import com.oracle.graal.pointsto.infrastructure.SubstitutionProcessor;
import com.oracle.graal.pointsto.meta.AnalysisType;
import com.oracle.graal.pointsto.meta.BaseLayerType;
import com.oracle.svm.core.SubstrateUtil;
import com.oracle.svm.core.util.BasedOnJDKClass;
import com.oracle.svm.core.util.VMError;
import com.oracle.svm.util.ReflectionUtil;

import jdk.vm.ci.meta.ResolvedJavaType;

/**
 * A substitution processor that renames classes generated by {@code InvokerBytecodeGenerator},
 * which are assigned more or less arbitrary names by the host VM, to stable names that are based on
 * the {@code LambdaForm} which they were compiled from.
 */
@BasedOnJDKClass(value = MethodHandle.class)
@BasedOnJDKClass(value = MethodType.class)
@BasedOnJDKClass(className = "java.lang.invoke.MethodHandleStatics")
@BasedOnJDKClass(className = "java.lang.invoke.ClassSpecializer")
@BasedOnJDKClass(className = "java.lang.invoke.ClassSpecializer", innerClass = "SpeciesData")
@BasedOnJDKClass(className = "java.lang.invoke.MemberName")
@BasedOnJDKClass(className = "java.lang.invoke.MethodHandleNatives")
@BasedOnJDKClass(className = "java.lang.invoke.LambdaForm")
@BasedOnJDKClass(className = "java.lang.invoke.LambdaForm", innerClass = "BasicType")
@BasedOnJDKClass(className = "java.lang.invoke.LambdaForm", innerClass = "Name")
@BasedOnJDKClass(className = "java.lang.invoke.LambdaForm", innerClass = "NamedFunction")
@BasedOnJDKClass(className = "java.lang.invoke.BoundMethodHandle")
@BasedOnJDKClass(className = "java.lang.invoke.DirectMethodHandle")
@BasedOnJDKClass(className = "java.lang.invoke.MethodHandleImpl", innerClass = "IntrinsicMethodHandle")
public class MethodHandleInvokerRenamingSubstitutionProcessor extends SubstitutionProcessor {
    /*
     * Reflective handles to JDK classes, fields and methods (mostly from java.lang.invoke), looked
     * up once through ReflectionUtil when this class is initialized. The methods below use them to
     * inspect class data, LambdaForms, their Names and NamedFunctions, MemberNames and method
     * handles when computing stable names.
     */
    private static final Class<?> METHOD_HANDLE_STATICS_CLASS = ReflectionUtil.lookupClass(false, "java.lang.invoke.MethodHandleStatics");
    private static final Field DEBUG_METHOD_HANDLE_NAMES_FIELD = ReflectionUtil.lookupField(METHOD_HANDLE_STATICS_CLASS, "DEBUG_METHOD_HANDLE_NAMES");
    private static final Class<?> CLASS_SPECIALIZER_CLASS = ReflectionUtil.lookupClass(false, "java.lang.invoke.ClassSpecializer");
    private static final Field CLASS_SPECIALIZER_META_TYPE_FIELD = ReflectionUtil.lookupField(CLASS_SPECIALIZER_CLASS, "metaType");
    private static final Class<?> SPECIES_DATA_CLASS = ReflectionUtil.lookupClass(false, "java.lang.invoke.ClassSpecializer$SpeciesData");
    private static final Method SPECIES_DATA_OUTER_METHOD = ReflectionUtil.lookupMethod(SPECIES_DATA_CLASS, "outer");
    private static final Field SPECIES_DATA_SPECIES_CODE_FIELD = ReflectionUtil.lookupField(SPECIES_DATA_CLASS, "speciesCode");
    private static final Method CLASS_GET_CLASS_DATA_METHOD = ReflectionUtil.lookupMethod(Class.class, "getClassData");
    private static final Class<?> MEMBER_NAME_CLASS = ReflectionUtil.lookupClass(false, "java.lang.invoke.MemberName");
    private static final Method MEMBER_NAME_GET_DECLARING_CLASS_METHOD = ReflectionUtil.lookupMethod(MEMBER_NAME_CLASS, "getDeclaringClass");
    private static final Method MEMBER_NAME_GET_NAME_METHOD = ReflectionUtil.lookupMethod(MEMBER_NAME_CLASS, "getName");
    private static final Method MEMBER_NAME_GET_METHOD_OR_FIELD_TYPE_METHOD = ReflectionUtil.lookupMethod(MEMBER_NAME_CLASS, "getMethodOrFieldType");
    private static final Method MEMBER_NAME_GET_REFERENCE_KIND_METHOD = ReflectionUtil.lookupMethod(MEMBER_NAME_CLASS, "getReferenceKind");
    private static final Class<?> METHOD_HANDLE_NATIVES_CLASS = ReflectionUtil.lookupClass(false, "java.lang.invoke.MethodHandleNatives");
    private static final Method METHOD_HANDLE_NATIVES_REF_KIND_NAME_METHOD = ReflectionUtil.lookupMethod(METHOD_HANDLE_NATIVES_CLASS, "refKindName", byte.class);
    private static final Class<?> LAMBDA_FORM_CLASS = ReflectionUtil.lookupClass(false, "java.lang.invoke.LambdaForm");
    private static final Field LAMBDA_FORM_CUSTOMIZED_FIELD = ReflectionUtil.lookupField(LAMBDA_FORM_CLASS, "customized");
    private static final Field LAMBDA_FORM_NAMES_FIELD = ReflectionUtil.lookupField(LAMBDA_FORM_CLASS, "names");
    private static final Class<?> BASIC_TYPE_CLASS = ReflectionUtil.lookupClass(false, "java.lang.invoke.LambdaForm$BasicType");
    private static final Class<?> NAME_CLASS = ReflectionUtil.lookupClass(false, "java.lang.invoke.LambdaForm$Name");
    private static final Field NAME_INDEX_FIELD = ReflectionUtil.lookupField(NAME_CLASS, "index");
    private static final Field NAME_CONSTRAINT_FIELD = ReflectionUtil.lookupField(NAME_CLASS, "constraint");
    private static final Field NAME_ARGUMENTS_FIELD = ReflectionUtil.lookupField(NAME_CLASS, "arguments");
    private static final Field NAME_FUNCTION_FIELD = ReflectionUtil.lookupField(NAME_CLASS, "function");
    private static final Class<?> NAMED_FUNCTION_CLASS = ReflectionUtil.lookupClass(false, "java.lang.invoke.LambdaForm$NamedFunction");
    private static final Field NAMED_FUNCTION_MEMBER_FIELD = ReflectionUtil.lookupField(NAMED_FUNCTION_CLASS, "member");
    private static final Method NAMED_FUNCTION_RESOLVED_HANDLE_METHOD = ReflectionUtil.lookupMethod(NAMED_FUNCTION_CLASS, "resolvedHandle");
    private static final Field FORM_FIELD = ReflectionUtil.lookupField(MethodHandle.class, "form");
    private static final Class<?> BOUND_METHOD_HANDLE_CLASS = ReflectionUtil.lookupClass(false, "java.lang.invoke.BoundMethodHandle");
    private static final Method BOUND_METHOD_HANDLE_SPECIES_DATA_METHOD = ReflectionUtil.lookupMethod(BOUND_METHOD_HANDLE_CLASS, "speciesData");
    private static final Class<?> DIRECT_METHOD_HANDLE_CLASS = ReflectionUtil.lookupClass(false, "java.lang.invoke.DirectMethodHandle");
    private static final Method DIRECT_METHOD_HANDLE_INTERNAL_MEMBER_NAME_METHOD = ReflectionUtil.lookupMethod(DIRECT_METHOD_HANDLE_CLASS, "internalMemberName");

    /*
     * For each kind of generated class (DMH, MH, VH): the substring that is searched for in a type
     * name to recognize the kind, and the prefix of the stable replacement name to which
     * findUniqueName appends the hexadecimal hash.
     */
    private static final String DMH_CLASS_NAME_SUBSTRING = "LambdaForm$DMH";
    private static final String DMH_STABLE_NAME_TEMPLATE = "Ljava/lang/invoke/LambdaForm$DMH.s";

    private static final String MH_CLASS_NAME_SUBSTRING = "LambdaForm$MH";
    private static final String MH_STABLE_NAME_TEMPLATE = "Ljava/lang/invoke/LambdaForm$MH.s";

    private static final String VH_CLASS_NAME_SUBSTRING = "LambdaForm$VH";
    private static final String VH_STABLE_NAME_TEMPLATE = "Ljava/lang/invoke/LambdaForm$VH.s";

    /* The analysis whose universe is inspected by checkAllTypeNames. */
    private final BigBang bb;

    /* Cache of the substitution created for each original type, filled by lookup. */
    private final ConcurrentMap<ResolvedJavaType, MethodHandleInvokerSubstitutionType> typeSubstitutions = new ConcurrentHashMap<>();
    /*
     * All stable names handed out so far. This is a plain HashSet; findUniqueName synchronizes on
     * the set itself while reading and updating it.
     */
    private final Set<String> uniqueTypeNames = new HashSet<>();

    /**
     * Creates a processor with empty substitution and name caches.
     *
     * @param bb the analysis whose universe is later inspected by {@link #checkAllTypeNames()}
     */
    MethodHandleInvokerRenamingSubstitutionProcessor(BigBang bb) {
        this.bb = bb;
    }

    /**
     * Returns the stably named substitution for generated method handle invoker classes and the
     * type itself for everything else. Substitutions are created at most once per type and cached.
     */
    @Override
    public ResolvedJavaType lookup(ResolvedJavaType type) {
        if (shouldReplace(type)) {
            return typeSubstitutions.computeIfAbsent(type, key -> getSubstitution(type, key));
        }
        return type;
    }

    /**
     * Checks whether the name of the given type contains one of the DMH, MH or VH class name
     * markers.
     */
    public static boolean isMethodHandleType(ResolvedJavaType type) {
        String typeName = type.getName();
        if (typeName.contains(DMH_CLASS_NAME_SUBSTRING)) {
            return true;
        }
        if (typeName.contains(MH_CLASS_NAME_SUBSTRING)) {
            return true;
        }
        return typeName.contains(VH_CLASS_NAME_SUBSTRING);
    }

    /**
     * A type needs a substitution when it is a method handle type that is neither already a
     * substitution type nor a {@link BaseLayerType}.
     */
    private static boolean shouldReplace(ResolvedJavaType type) {
        if (type instanceof MethodHandleInvokerSubstitutionType || type instanceof BaseLayerType) {
            return false;
        }
        return isMethodHandleType(type);
    }

    /**
     * Creates the substitution type for a generated method handle invoker class. The stable name is
     * derived from a hash of the {@code LambdaForm} found in the class data of the original class,
     * mixed with information about a customization when one is present.
     *
     * @param type the type whose name determines the kind (DMH, MH or VH) of the generated class
     * @param original the type whose Java class provides the class data and which is wrapped by the
     *            returned substitution
     * @return the substitution type carrying a unique stable name
     */
    private MethodHandleInvokerSubstitutionType getSubstitution(ResolvedJavaType type, ResolvedJavaType original) {
        int hash;
        boolean isDirect = type.getName().contains(DMH_CLASS_NAME_SUBSTRING);
        try {
            Object lambdaForm;
            Object customizedMemberName = null;
            boolean customizedArbitraryMethodHandle = false;
            Class<?> clazz = OriginalClassProvider.getJavaClass(original);
            Object classData = CLASS_GET_CLASS_DATA_METHOD.invoke(clazz);
            if (LAMBDA_FORM_CLASS.isInstance(classData)) {
                lambdaForm = classData;
            } else if (classData instanceof List<?> list) {
                VMError.guarantee(list.size() > 1, "The classData cannot be a list with fewer than 2 elements.");
                lambdaForm = list.get(0);
                VMError.guarantee(LAMBDA_FORM_CLASS.isInstance(lambdaForm), "Expected classData to contain LambdaForm at the start of the list: %s", classData);
                if (isDirect) {
                    VMError.guarantee(list.size() == 2, "Expected classData of a direct method handle to contain exactly two elements: %s", classData);
                    Object customizedHandle = list.get(1);
                    VMError.guarantee(DIRECT_METHOD_HANDLE_CLASS.isInstance(customizedHandle) && LAMBDA_FORM_CUSTOMIZED_FIELD.get(lambdaForm) == customizedHandle,
                                    "Expected classData to contain LambdaForm and its customization: %s", classData);

                    /*
                     * Two customized direct method handles with the same member would cause an
                     * aliasing issue. Avoiding it would require to disable method handle
                     * customization.
                     */
                    customizedMemberName = DIRECT_METHOD_HANDLE_INTERNAL_MEMBER_NAME_METHOD.invoke(customizedHandle);
                } else {
                    /*
                     * The classData list contains parts of the lambda form such as arguments,
                     * resolved method handle targets, classes used in type casts, and the
                     * LambdaForm itself, see callers of InvokerBytecodeGenerator.classData(Object).
                     * We only extract the LambdaForm from the classData as it contains the other
                     * objects, and they are included in our hash representation when we process the
                     * LambdaForm.
                     *
                     * When one of the arguments of a name is an arbitrary object, it is stored in
                     * the classData. The arguments are all checked later, and we throw if one
                     * argument is an arbitrary object.
                     *
                     * All resolvedHandles from the names that are not statically invocable are
                     * stored in the classData. We already recurse through all resolvedHandles.
                     *
                     * For each cast in the method handle, the class into which the object is cast
                     * is saved in the classData. The casts all depend on the types from the
                     * methodType and they are already included in the hash.
                     */
                    Object customizedHandle = LAMBDA_FORM_CUSTOMIZED_FIELD.get(lambdaForm);
                    if (customizedHandle != null) {
                        VMError.guarantee(customizedHandle == list.get(1), "Expected the customization to be right after the LambdaForm: %s", list.get(1));
                    }

                    /*
                     * Two customized arbitrary method handles with the same original lambda form
                     * will produce an aliasing issue. Avoiding it would require finding a way to
                     * distinguish them without using their lambda form as they are equal.
                     */
                    customizedArbitraryMethodHandle = true;
                }
            } else {
                throw VMError.shouldNotReachHere("Unexpected classData: %s", classData);
            }
            hash = computeLambdaFormHash(lambdaForm, isDirect);
            if (customizedMemberName != null) {
                /* MemberName.hashCode() also includes identity hash codes of Class<?> objects. */
                hash = hash * 31 + memberNameToString(customizedMemberName).hashCode();
            }
            if (customizedArbitraryMethodHandle) {
                hash = hash * 31 + "customized".hashCode();
            }
        } catch (ReflectiveOperationException e) {
            throw VMError.shouldNotReachHere(e);
        }
        boolean isVarHandle = type.getName().contains(VH_CLASS_NAME_SUBSTRING);
        return new MethodHandleInvokerSubstitutionType(original, findUniqueName(hash, isDirect, isVarHandle));
    }

    /**
     * Computes a hash for the given lambda form: direct method handles hash the string
     * representation of the form, all others use {@link #getUniqueStableHash(Object)}.
     */
    private int computeLambdaFormHash(Object lambdaForm, boolean isDirect) {
        /*
         * LambdaForm.hashCode() is not stable between image builds because it incorporates identity
         * hash codes of objects such as those of Class<?> that don't override hashCode(). For that
         * reason, we compute a hash code from LambdaForm.toString(). It might also not be perfectly
         * unique because the string contains unqualified class names and can contain string
         * representations of constraints that may be arbitrary objects, but it should typically be
         * distinct and stable.
         */
        if (isDirect) {
            return lambdaForm.toString().hashCode();
        }
        try {
            return getUniqueStableHash(lambdaForm);
        } catch (ReflectiveOperationException e) {
            throw VMError.shouldNotReachHere(e);
        }
    }

    /**
     * Before recursively computing the hash of the inner method handles, the parts of the lambda
     * form string representation that would cause an unstable name have to be replaced. Various
     * assertions are run on the lambda form to ensure the name is stable between images and that
     * two different lambda forms cannot have the same name.
     *
     * @param lambdaForm the {@code LambdaForm} instance to hash
     * @return a hash combining the adjusted string representation of the lambda form with the
     *         hashes of its constraints, members, method types, species data and inner lambda forms
     * @throws ReflectiveOperationException if a reflective access to the lambda form internals fails
     */
    private int getUniqueStableHash(Object lambdaForm) throws ReflectiveOperationException {
        String lambdaFormString = lambdaForm.toString();
        int hash = 0;

        /* Loop-invariant: the BasicType[] class is needed for every argument check below. */
        Class<?> basicTypeArrayClass = BASIC_TYPE_CLASS.arrayType();

        Object names = LAMBDA_FORM_NAMES_FIELD.get(lambdaForm);
        int namesLength = Array.getLength(names);
        for (int i = 0; i < namesLength; ++i) {
            Object name = Array.get(names, i);
            /*
             * A LambdaForm$Name.toString without an index uses its identity hash code, which is not
             * stable between two different JVM instances.
             */
            assert NAME_INDEX_FIELD.getShort(name) >= 0 : "The name " + name + " from the lambda form " + lambdaForm + " has no index set, which produces unstable names.";

            Object constraint = NAME_CONSTRAINT_FIELD.get(name);
            if (constraint != null) {
                if (constraint instanceof Class<?> classConstraint) {
                    /*
                     * If the constraint is a class, the Name.paramString uses its simple name. To
                     * avoid potential aliasing, the hash of the qualified name is mixed in the
                     * result.
                     */
                    hash = hash * 31 + classConstraint.getName().hashCode();
                } else if (SPECIES_DATA_CLASS.isInstance(constraint)) {
                    hash = hash * 31 + getSpeciesDataHash(constraint);
                } else {
                    throw new AssertionError("The name " + name + " has a constraint that could cause an unstable name: " + constraint);
                }
            }

            Object arguments = NAME_ARGUMENTS_FIELD.get(name);
            if (arguments != null) {
                int argumentsLength = Array.getLength(arguments);
                for (int j = 0; j < argumentsLength; ++j) {
                    Object argument = Array.get(arguments, j);
                    /* Class.isInstance(null) is false, so a null argument is never a BasicType[]. */
                    boolean isBasicTypeArray = basicTypeArrayClass.isInstance(argument);
                    if (argument != null && !(argument instanceof Integer) && !NAME_CLASS.isInstance(argument) && !isBasicTypeArray) {
                        throw new AssertionError("Lambda form argument " + argument + " is of type " + argument.getClass() + " which might produce unstable name.");
                    }

                    /*
                     * An argument can be a BasicType[]. In this case, the toString method is used
                     * by Name.exprString and the resulting string representation of the array
                     * contains its identity hash code (but not its elements). This will cause
                     * unstable name between two JVM instances. To solve this, the string
                     * representation of the BasicType[] is replaced by the output of
                     * Arrays.toString, which lists the elements instead.
                     */
                    if (isBasicTypeArray) {
                        lambdaFormString = lambdaFormString.replace(String.valueOf(argument), Arrays.toString((Object[]) argument));
                    }
                }
            }

            Object function = NAME_FUNCTION_FIELD.get(name);
            if (function != null) {
                Object member = NAMED_FUNCTION_MEMBER_FIELD.get(function);
                if (member != null) {
                    /*
                     * The LambdaForm$NamedFunction.toString uses the simple name of the member
                     * class. To avoid potential aliasing, the hash of the member is mixed in the
                     * result.
                     */
                    hash = hash * 31 + memberNameToString(member).hashCode();
                }
                /*
                 * The method handle of the NamedFunction is used in the string representation. To
                 * avoid potential aliasing, the hash of the descriptor string is mixed in the
                 * result.
                 */
                Object resolvedHandle = NAMED_FUNCTION_RESOLVED_HANDLE_METHOD.invoke(function);
                MethodType methodType = ((MethodHandle) resolvedHandle).type();
                hash = hash * 31 + methodType.descriptorString().hashCode();

                if (BOUND_METHOD_HANDLE_CLASS.isInstance(resolvedHandle)) {
                    /*
                     * BoundMethodHandle.internalValues calls BoundMethodHandle.arg, which retrieves
                     * the object that was bound to the corresponding argument, and return its
                     * string representation. This method is only used if the debug method handle
                     * names are activated. The object used may not have a stable string
                     * representation, which would lead to an unstable name.
                     */
                    assert !DEBUG_METHOD_HANDLE_NAMES_FIELD.getBoolean(null) : "The method handle " + resolvedHandle +
                                    " with debug method handle names can contain the string representation from any object, which would cause the name to be unstable.";

                    /*
                     * Without the debug method handle names, the MethodHandle.toString method does
                     * not include any additional detail if the method handle is a bound method
                     * handle. To avoid potential aliasing, the custom hash of the species data is
                     * mixed with the result.
                     */
                    Object speciesData = BOUND_METHOD_HANDLE_SPECIES_DATA_METHOD.invoke(resolvedHandle);
                    hash = hash * 31 + getSpeciesDataHash(speciesData);
                }

                Object innerLambdaForm = FORM_FIELD.get(resolvedHandle);
                hash = hash * 31 + computeLambdaFormHash(innerLambdaForm, DIRECT_METHOD_HANDLE_CLASS.isInstance(resolvedHandle));
            }
        }
        return hash * 31 + lambdaFormString.hashCode();
    }

    /**
     * Builds a string for a {@code MemberName} without relying on {@code MemberName.toString}. The
     * result is the concatenation of the qualified name of the declaring class, the member name,
     * the descriptor string of its method or field type, and the name of its reference kind.
     */
    private static String memberNameToString(Object memberName) throws ReflectiveOperationException {
        Class<?> owner = (Class<?>) MEMBER_NAME_GET_DECLARING_CLASS_METHOD.invoke(memberName);
        String simpleName = (String) MEMBER_NAME_GET_NAME_METHOD.invoke(memberName);
        MethodType type = (MethodType) MEMBER_NAME_GET_METHOD_OR_FIELD_TYPE_METHOD.invoke(memberName);
        byte kind = (byte) MEMBER_NAME_GET_REFERENCE_KIND_METHOD.invoke(memberName);
        String kindName = (String) METHOD_HANDLE_NATIVES_REF_KIND_NAME_METHOD.invoke(null, kind);

        StringBuilder text = new StringBuilder();
        text.append(owner.getName());
        text.append(simpleName);
        text.append(type.descriptorString());
        text.append(kindName);
        return text.toString();
    }

    /**
     * Hashes a {@code SpeciesData} from the qualified names of the meta type of its enclosing
     * {@code ClassSpecializer} and of its species code class. {@code SpeciesData.toString} only
     * uses the simple names, which could alias.
     */
    private static int getSpeciesDataHash(Object speciesData) throws ReflectiveOperationException {
        Object specializer = SPECIES_DATA_OUTER_METHOD.invoke(speciesData);
        Class<?> metaTypeClass = (Class<?>) CLASS_SPECIALIZER_META_TYPE_FIELD.get(specializer);
        Class<?> speciesCodeClass = (Class<?>) SPECIES_DATA_SPECIES_CODE_FIELD.get(speciesData);

        int metaTypeHash = metaTypeClass.getName().hashCode();
        int speciesCodeHash = speciesCodeClass.getName().hashCode();
        return 31 * metaTypeHash + speciesCodeHash;
    }

    /**
     * Produces a name that has not been handed out before. The name consists of the template for
     * the kind of class, the hash in hexadecimal, and a trailing semicolon. When that name is
     * already taken, {@code _1}, {@code _2}, ... is inserted before the semicolon until a free name
     * is found. The chosen name is recorded in {@code uniqueTypeNames}.
     */
    private String findUniqueName(int hashCode, boolean isDirect, boolean isVarHandle) {
        String template = isDirect ? DMH_STABLE_NAME_TEMPLATE : (isVarHandle ? VH_STABLE_NAME_TEMPLATE : MH_STABLE_NAME_TEMPLATE);
        String stem = template + Integer.toHexString(hashCode);
        synchronized (uniqueTypeNames) {
            String candidate = stem + ";";
            /* Set.add returns false when the candidate is already present. */
            for (int suffix = 1; !uniqueTypeNames.add(candidate); suffix++) {
                candidate = stem + "_" + suffix + ";";
            }
            return candidate;
        }
    }

    /**
     * Checks whether the given name is unaffected by name collisions recorded so far. A name
     * without an underscore is always reported as stable. Otherwise, the part before the last
     * underscore is taken and the name is reported as stable only if no name with that prefix and
     * the suffix {@code _1} has been handed out.
     *
     * NOTE(review): a name without a suffix is reported as stable even if a {@code _1} sibling
     * exists; confirm that this is intended by the callers.
     *
     * @param methodHandleName the name to check, in the format produced by this processor
     * @return {@code true} if no collision is recorded for the name
     */
    public boolean isNameAlwaysStable(String methodHandleName) {
        int lastIndex = methodHandleName.lastIndexOf('_');
        if (lastIndex < 0) {
            return true;
        }
        String firstCollisionName = methodHandleName.substring(0, lastIndex) + "_1;";
        /*
         * uniqueTypeNames is a plain HashSet that findUniqueName mutates while holding this same
         * lock, so the read has to be synchronized as well.
         */
        synchronized (uniqueTypeNames) {
            return !uniqueTypeNames.contains(firstCollisionName);
        }
    }

    /**
     * Verifies that every relevant type of the analysis universe has been substituted and that no
     * two method handle types share a name. Throws an {@link AssertionError} when a check fails or
     * when assertions are not enabled.
     *
     * @return {@code true} when all checks pass (presumably so that it can be invoked from an
     *         {@code assert} statement)
     */
    boolean checkAllTypeNames() {
        if (!SubstrateUtil.assertionsEnabled()) {
            throw new AssertionError("Expensive check: should only run with assertions enabled.");
        }

        List<AnalysisType> types = bb.getUniverse().getTypes();

        for (AnalysisType analysisType : types) {
            if (shouldReplace(analysisType.getWrapped())) {
                throw new AssertionError("All relevant types must have been substituted.");
            }
        }

        Set<String> seenNames = new HashSet<>();
        for (AnalysisType analysisType : types) {
            if (!isMethodHandleType(analysisType)) {
                continue;
            }
            String typeName = analysisType.getName();
            /* Set.add returns false when the name was already recorded. */
            if (!seenNames.add(typeName)) {
                throw new AssertionError("Duplicate name: " + typeName);
            }
        }
        return true;
    }
}
